#!/usr/bin/env python3
"""
Phase 3: Cross-Paper Leave-One-Region-Out Jackknife
Tests 4 representative findings for sensitivity to regional composition:
  1. Trilemma peg-vs-float logit (collapsed; proxied here by Z₁ → CA/GDP)
  2. Crises banking old_dep (sign-reversed)
  3. Net/gross income balance (robust)
  4. Automation I/Y (robust)

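A finding is classified as fragile when the coefficient of variation (CV) of its
leave-one-region-out coefficient estimates exceeds 30%.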
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# PanelGLS lives in the multilateral project's src/ tree; make it importable
# before importing it.
sys.path.insert(0, str(PROJECT_DIR.joinpath("multilateral", "src")))
from model import PanelGLS

FOLLOWUP_DIR = PROJECT_DIR.joinpath("multilateral", "followup")
TABLE_DIR = PROJECT_DIR.joinpath("fragility", "output", "tables")
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# ── Load panel ─────────────────────────────────────────────────────────────
print("=" * 70, "PHASE 3: CROSS-PAPER LEAVE-ONE-REGION-OUT JACKKNIFE", "=" * 70, sep="\n")

# Keep the 1986–2024 window, inclusive at both ends.
panel = pd.read_csv(FOLLOWUP_DIR.joinpath("data", "processed", "full_panel.csv"))
panel = panel.loc[panel['year'].between(1986, 2024)].copy()
print(f"Panel: {panel['iso3'].nunique()} countries, {len(panel):,} obs")

# ── Region classification ─────────────────────────────────────────────────
# World Bank-style regional grouping (plus TWN) keyed by ISO3 code.
# Codes not listed here are classified as 'Other' and are therefore never
# dropped by the jackknife, so any economy missing from this table is
# silently kept in every replicate. SDN, FSM, PRK, PRI and BMU were added
# because they belong to the regions below but had been omitted.
REGION_MAP = {
    'East Asia & Pacific': {
        'AUS', 'BRN', 'CHN', 'FJI', 'FSM', 'HKG', 'IDN', 'JPN', 'KHM', 'KIR',
        'KOR', 'LAO', 'MAC', 'MHL', 'MMR', 'MNG', 'MYS', 'NRU', 'NZL', 'PHL',
        'PLW', 'PNG', 'PRK', 'SGP', 'SLB', 'THA', 'TLS', 'TON', 'TUV', 'TWN',
        'VNM', 'VUT', 'WSM'
    },
    'Europe & Central Asia': {
        'ALB', 'AND', 'ARM', 'AUT', 'AZE', 'BEL', 'BGR', 'BIH', 'BLR', 'CHE',
        'CYP', 'CZE', 'DEU', 'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GEO',
        'GRC', 'HRV', 'HUN', 'IRL', 'ISL', 'ITA', 'KAZ', 'KGZ', 'LTU', 'LUX',
        'LVA', 'MDA', 'MKD', 'MLT', 'MNE', 'NLD', 'NOR', 'POL', 'PRT', 'ROU',
        'RUS', 'SMR', 'SRB', 'SVK', 'SVN', 'SWE', 'TJK', 'TKM', 'TUR', 'UKR',
        'UZB', 'XKX'
    },
    'Latin America & Caribbean': {
        'ABW', 'ARG', 'ATG', 'BHS', 'BLZ', 'BOL', 'BRA', 'BRB', 'CHL', 'COL',
        'CRI', 'CUB', 'CUW', 'DMA', 'DOM', 'ECU', 'GRD', 'GTM', 'GUY', 'HND',
        'HTI', 'JAM', 'KNA', 'LCA', 'MEX', 'NIC', 'PAN', 'PER', 'PRI', 'PRY',
        'SLV', 'SUR', 'SXM', 'TCA', 'TTO', 'URY', 'VCT', 'VEN'
    },
    'Middle East & North Africa': {
        'ARE', 'BHR', 'DJI', 'DZA', 'EGY', 'IRN', 'IRQ', 'ISR', 'JOR', 'KWT',
        'LBN', 'LBY', 'MAR', 'OMN', 'PSE', 'QAT', 'SAU', 'SYR', 'TUN', 'YEM'
    },
    'South Asia': {
        'AFG', 'BGD', 'BTN', 'IND', 'LKA', 'MDV', 'NPL', 'PAK'
    },
    'Sub-Saharan Africa': {
        'AGO', 'BDI', 'BEN', 'BFA', 'BWA', 'CAF', 'CIV', 'CMR', 'COD', 'COG',
        'COM', 'CPV', 'ERI', 'ETH', 'GAB', 'GHA', 'GIN', 'GMB', 'GNB', 'GNQ',
        'KEN', 'LBR', 'LSO', 'MDG', 'MLI', 'MOZ', 'MRT', 'MUS', 'MWI', 'NAM',
        'NER', 'NGA', 'RWA', 'SDN', 'SEN', 'SLE', 'SOM', 'SSD', 'STP', 'SWZ',
        'SYC', 'TCD', 'TGO', 'TZA', 'UGA', 'ZAF', 'ZMB', 'ZWE'
    },
    'North America': {'BMU', 'CAN', 'USA'},
}

def get_region(iso3):
    """Return the REGION_MAP region containing *iso3*, or 'Other' if none does.

    Regions are checked in REGION_MAP insertion order; the first match wins.
    """
    return next(
        (region_name for region_name, members in REGION_MAP.items() if iso3 in members),
        'Other',
    )

panel['region'] = panel['iso3'].apply(get_region)
# Regions in first-appearance order; 'Other' is never dropped as a unit.
regions = [name for name in panel['region'].unique() if name != 'Other']
print(f"Regions: {regions}")
countries_per_region = panel.groupby('region')['iso3'].nunique()
for name in regions:
    print(f"  {name}: {countries_per_region[name]} countries")


def run_jackknife(dep_var, indep_vars, data, label):
    """
    Run leave-one-region-out jackknife for a given specification.

    Fits PanelGLS on the complete-case sample, then refits once for each
    region in the module-level ``regions`` list with that region's rows
    removed. Regions that have no rows in the estimation sample are skipped:
    dropping them would just reproduce the full-sample fit and bias the CV
    toward zero. Replicates with fewer than 10 remaining countries are skipped.

    Parameters
    ----------
    dep_var : str
        Dependent-variable column in ``data``.
    indep_vars : sequence of str
        Regressor columns. The tracked coefficient (reported as ``Z1_*``) is
        ``Z_1`` when present, otherwise the first regressor.
    data : DataFrame
        Panel with ``iso3``, ``year`` and ``region`` columns plus the
        variables in the specification.
    label : str
        Name used in printed messages and in the ``label`` output column.

    Returns
    -------
    DataFrame or None
        One row for the full sample plus one per dropped region, holding the
        tracked coefficient, its SE and p-value, R², N and country count, and
        the columns ``CV_pct`` (std / |mean| of the leave-one-out estimates,
        in %) and ``fragile`` (CV > 30, or None when CV is undefined).
        Returns None when there are no regressors or fewer than 50 complete
        observations.
    """
    # Accept any sequence; a tuple would otherwise be read by pandas as a
    # single (MultiIndex) column key.
    indep_vars = list(indep_vars)
    if not indep_vars:
        print(f"  SKIP {label}: no regressors available")
        return None

    # Full sample: complete cases on every variable in the specification
    comp = data.dropna(subset=[dep_var] + indep_vars).copy()
    if len(comp) < 50:
        print(f"  SKIP {label}: only {len(comp)} obs")
        return None

    gls = PanelGLS()
    gls.fit(comp[dep_var].values, comp[indep_vars].values,
            comp['iso3'].values, comp['year'].values)

    z1_idx = indep_vars.index('Z_1') if 'Z_1' in indep_vars else 0

    rows = [{
        'label': label,
        'dropped_region': 'None (full)',
        'N': gls.n_obs,
        'countries': comp['iso3'].nunique(),
        'Z1_coef': gls.beta[z1_idx],
        'Z1_se': gls.se[z1_idx],
        'Z1_p': gls.pvalues[z1_idx],
        'R2': gls.r_squared,
    }]

    # Leave-one-region-out
    z1_vals = []
    for region in regions:
        sub = comp[comp['region'] != region].copy()
        if len(sub) == len(comp):
            # Region not present in this estimation sample: nothing to drop.
            continue
        if sub['iso3'].nunique() < 10:
            continue
        gls_j = PanelGLS()
        gls_j.fit(sub[dep_var].values, sub[indep_vars].values,
                  sub['iso3'].values, sub['year'].values)

        z1_j = gls_j.beta[z1_idx]
        z1_vals.append(z1_j)

        rows.append({
            'label': label,
            'dropped_region': region,
            'N': gls_j.n_obs,
            'countries': sub['iso3'].nunique(),
            'Z1_coef': z1_j,
            'Z1_se': gls_j.se[z1_idx],
            'Z1_p': gls_j.pvalues[z1_idx],
            'R2': gls_j.r_squared,
        })

    df = pd.DataFrame(rows)

    # CV across leave-one-out estimates (np.std: population std, ddof=0).
    # Undefined with fewer than two replicates or a zero mean.
    if len(z1_vals) > 1 and np.mean(z1_vals) != 0:
        cv = np.std(z1_vals) / abs(np.mean(z1_vals)) * 100
    else:
        cv = np.nan

    df['CV_pct'] = cv
    df['fragile'] = cv > 30 if pd.notna(cv) else None

    return df


# ════════════════════════════════════════════════════════════════════════════
# TEST 1: Trilemma — Z₁ → CA/GDP (using extended baseline as proxy for peg)
# The peg logit collapsed 13×; test CA conditionality instead
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 1: TRILEMMA — Z₁ → CA/GDP (extended baseline)")
print("=" * 70)

# Load trilemma panel if available, else use followup panel
tri_path = PROJECT_DIR / "trilemma" / "data" / "processed" / "trilemma_panel.csv"
if tri_path.exists():
    tri = pd.read_csv(tri_path)
    tri = tri[(tri['year'] >= 1986) & (tri['year'] <= 2024)].copy()
    tri['region'] = tri['iso3'].apply(get_region)
    # ERS as proxy for regime choice
    tri_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen']
    avail_vars = [v for v in tri_vars if v in tri.columns]
    if 'ers_index' in tri.columns:
        avail_vars.append('ers_index')
    jk1 = run_jackknife('ca_gdp', avail_vars, tri, 'Trilemma: Z→CA')
else:
    # Fallback: use followup panel
    ca_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_rel_opw']
    avail_vars = [v for v in ca_vars if v in panel.columns]
    jk1 = run_jackknife('ca_gdp', avail_vars, panel, 'Trilemma: Z→CA')

if jk1 is not None:
    print(jk1.to_string(index=False))
    cv1 = jk1['CV_pct'].iloc[0]
    # A NaN CV (too few replicates or zero mean) is undetermined, not STABLE.
    if pd.isna(cv1):
        verdict1 = 'n/a'
    elif jk1['fragile'].iloc[0]:
        verdict1 = 'FRAGILE'
    else:
        verdict1 = 'STABLE'
    print(f"\nCV = {cv1:.1f}%  {verdict1}")


# ════════════════════════════════════════════════════════════════════════════
# TEST 2: Crises — old_dep → banking crisis (sign-reversed)
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 2: CRISES — old_dep → banking crisis")
print("=" * 70)

# Need banking crisis indicator — check crises panel.
# Sorted so the first usable file is chosen deterministically.
crises_path = PROJECT_DIR / "crises" / "data" / "processed"
crises_files = sorted(crises_path.glob("*.csv")) if crises_path.exists() else []

# Stays None unless a crises panel yields a result; the combine step below
# relies on jk2 always being defined.
jk2 = None
for cf in crises_files:
    try:
        cpanel = pd.read_csv(cf)
        bank_col = next((c for c in ('banking_crisis', 'crisis_banking')
                         if c in cpanel.columns), None)
        if bank_col is None:
            continue
        cpanel['region'] = cpanel['iso3'].apply(get_region)
        cpanel = cpanel[(cpanel['year'] >= 1986) & (cpanel['year'] <= 2024)].copy()
        bank_vars = [v for v in ['old_dep', 'youth_dep', 'kaopen', 'rgdp_growth'] if v in cpanel.columns]
        jk2 = run_jackknife(bank_col, bank_vars, cpanel, 'Crises: old_dep→banking')
    except Exception as exc:
        # Best-effort scan of unknown CSVs: report the problem, try the next file.
        print(f"  Skipping {cf.name}: {type(exc).__name__}: {exc}")
        continue
    if jk2 is not None:
        print(jk2.to_string(index=False))
        print(f"\nCV = {jk2['CV_pct'].iloc[0]:.1f}%")
    break

if jk2 is None:
    # Fallback: use old_dep on CA as proxy
    print("  No usable crises panel found; using old_dep→CA as proxy")
    ca_vars2 = ['old_dep', 'youth_dep', 'fiscal_bal_gdp', 'kaopen']
    avail2 = [v for v in ca_vars2 if v in panel.columns]
    jk2 = run_jackknife('ca_gdp', avail2, panel, 'Crises proxy: old_dep→CA')
    if jk2 is not None:
        print(jk2.to_string(index=False))
        print(f"\nCV = {jk2['CV_pct'].iloc[0]:.1f}%")


# ════════════════════════════════════════════════════════════════════════════
# TEST 3: Net/Gross — Z₁ → income balance
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 3: NET/GROSS — Z₁ → income balance")
print("=" * 70)

# Check for income_balance variable.
# Sorted so the first usable file is chosen deterministically.
ng_path = PROJECT_DIR / "net_gross" / "data" / "processed"
ng_files = sorted(ng_path.glob("*.csv")) if ng_path.exists() else []

# Try net_gross panel
jk3 = None
income_bal_done = False
for nf in ng_files:
    try:
        ngpanel = pd.read_csv(nf)
        ib_col = next((c for c in ('income_balance_gdp', 'income_bal_gdp')
                       if c in ngpanel.columns), None)
        if ib_col is None:
            continue
        ngpanel['region'] = ngpanel['iso3'].apply(get_region)
        ngpanel = ngpanel[(ngpanel['year'] >= 1986) & (ngpanel['year'] <= 2024)].copy()
        ib_vars = [v for v in ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag'] if v in ngpanel.columns]
        jk3 = run_jackknife(ib_col, ib_vars, ngpanel, 'Net/Gross: Z→income_bal')
    except Exception as exc:
        # Best-effort scan of unknown CSVs: report the problem, try the next file.
        print(f"  Skipping {nf.name}: {type(exc).__name__}: {exc}")
        continue
    if jk3 is not None:
        print(jk3.to_string(index=False))
        cv3 = jk3['CV_pct'].iloc[0]
        # A NaN CV (too few replicates or zero mean) is undetermined, not STABLE.
        if pd.isna(cv3):
            verdict3 = 'n/a'
        elif jk3['fragile'].iloc[0]:
            verdict3 = 'FRAGILE'
        else:
            verdict3 = 'STABLE'
        print(f"\nCV = {cv3:.1f}%  {verdict3}")
        income_bal_done = True
    break

if not income_bal_done:
    # Fallback: use ca_gdp as dependent var
    print("  No income balance variable found; using CA/GDP")
    ca_vars3 = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_rel_opw']
    avail3 = [v for v in ca_vars3 if v in panel.columns]
    jk3 = run_jackknife('ca_gdp', avail3, panel, 'Net/Gross proxy: Z→CA')
    if jk3 is not None:
        print(jk3.to_string(index=False))
        print(f"\nCV = {jk3['CV_pct'].iloc[0]:.1f}%")


# ════════════════════════════════════════════════════════════════════════════
# TEST 4: Automation — Z₁ → I/Y (robust finding)
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("TEST 4: AUTOMATION — Z₁ → gross_investment_gdp")
print("=" * 70)

# First available investment/GDP measure, in order of preference.
iy_col = next((c for c in ('gross_investment_gdp', 'gross_fixed_investment_gdp')
               if c in panel.columns), None)

if iy_col:
    iy_vars = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag', 'log_rel_opw']
    avail4 = [v for v in iy_vars if v in panel.columns]
    jk4 = run_jackknife(iy_col, avail4, panel, f'Automation: Z→{iy_col}')
    if jk4 is not None:
        print(jk4.to_string(index=False))
        cv4 = jk4['CV_pct'].iloc[0]
        # A NaN CV (too few replicates or zero mean) is undetermined, not STABLE.
        if pd.isna(cv4):
            verdict4 = 'n/a'
        elif jk4['fragile'].iloc[0]:
            verdict4 = 'FRAGILE'
        else:
            verdict4 = 'STABLE'
        print(f"\nCV = {cv4:.1f}%  {verdict4}")
else:
    print("  No investment/GDP variable found")
    jk4 = None


# ════════════════════════════════════════════════════════════════════════════
# Combine and save
# ════════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("COMBINED JACKKNIFE RESULTS")
print("=" * 70)

all_jk = [jk for jk in (jk1, jk2, jk3, jk4) if jk is not None]

if all_jk:
    combined = pd.concat(all_jk, ignore_index=True)
    combined.to_csv(TABLE_DIR / "table3_jackknife_results.csv", index=False)

    # Summary table: one row per test
    summary_rows = []
    for label in combined['label'].unique():
        sub = combined[combined['label'] == label]
        full_row = sub[sub['dropped_region'] == 'None (full)']
        jack_rows = sub[sub['dropped_region'] != 'None (full)']

        if len(full_row) > 0 and len(jack_rows) > 0:
            z1_full = full_row['Z1_coef'].values[0]
            z1_min = jack_rows['Z1_coef'].min()
            z1_max = jack_rows['Z1_coef'].max()
            cv = sub['CV_pct'].iloc[0]
            fragile = sub['fragile'].iloc[0]

            # Which region matters most? (largest shift from the full-sample estimate)
            deviation = (jack_rows['Z1_coef'] - z1_full).abs()
            if deviation.notna().any():
                most_influential = jack_rows.loc[deviation.idxmax(), 'dropped_region']
            else:
                most_influential = 'n/a'

            summary_rows.append({
                'Test': label,
                'Z1_full': z1_full,
                'Z1_min': z1_min,
                'Z1_max': z1_max,
                'Range': z1_max - z1_min,
                'CV_pct': cv,
                'Fragile': fragile,
                'Most_influential_region': most_influential,
            })

    summary_df = pd.DataFrame(summary_rows)
    summary_df.to_csv(TABLE_DIR / "table3_jackknife_summary.csv", index=False)

    # Markdown
    md = ["# Table 3: Leave-One-Region-Out Jackknife\n"]
    md.append("CV > 30% indicates fragility to regional composition.\n")
    md.append("| Test | Z₁ (full) | Z₁ range | CV (%) | Fragile? | Most influential region |")
    md.append("|:---|---:|:---|---:|:---|:---|")
    for _, r in summary_df.iterrows():
        # An undefined CV must not be reported as "not fragile".
        if pd.isna(r['CV_pct']):
            frag = "n/a"
        elif r['Fragile']:
            frag = "**YES**"
        else:
            frag = "No"
        md.append(f"| {r['Test']} | {r['Z1_full']:.2f} | [{r['Z1_min']:.2f}, {r['Z1_max']:.2f}] | "
                  f"{r['CV_pct']:.1f} | {frag} | {r['Most_influential_region']} |")

    md.append("\n## Detailed Results\n")
    for label in combined['label'].unique():
        md.append(f"\n### {label}\n")
        sub = combined[combined['label'] == label]
        md.append("| Dropped region | N | Countries | Z₁ | SE | p-value | R² |")
        md.append("|:---|---:|---:|---:|---:|---:|---:|")
        for _, r in sub.iterrows():
            sig = '***' if r['Z1_p'] < 0.001 else '**' if r['Z1_p'] < 0.01 else '*' if r['Z1_p'] < 0.05 else ''
            md.append(f"| {r['dropped_region']} | {r['N']} | {r['countries']} | "
                      f"{r['Z1_coef']:.3f}{sig} | {r['Z1_se']:.3f} | {r['Z1_p']:.4f} | {r['R2']:.4f} |")

    md_text = "\n".join(md)
    # Explicit UTF-8: the table contains non-ASCII (Z₁, R², arrows) that a
    # non-UTF-8 locale default encoding could fail to write.
    (TABLE_DIR / "table3_jackknife_results.md").write_text(md_text, encoding="utf-8")
    print(md_text)
    print(f"\nSaved results to {TABLE_DIR}")
else:
    print("  No jackknife results produced; nothing saved")

print("Phase 3 complete.")
